#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#


import sys
import numpy
from pyscf.lib import logger
from pyscf import gto
from pyscf.gto.basis import _format_basis_name
from pyscf import ao2mo
from pyscf.data import elements
from pyscf.lib.exceptions import BasisNotFoundError
from pyscf.df.autoaux import autoaux, autoabs
from pyscf import __config__

# Defaults of the even-tempered basis (ETB) helpers below. Each value can be
# overridden through pyscf.__config__ and is read from the config key that
# carries its own name (the two keys used to be swapped, so that setting one
# of them changed the other constant).
# Predefined fitting basis assigned to the elements below FIRST_ETB_ELEMENT.
DFBASIS = getattr(__config__, 'df_addons_aug_dfbasis', 'weigend')
# Ratio between consecutive exponents of an even-tempered shell.
ETB_BETA = getattr(__config__, 'df_addons_aug_etb_beta', 2.0)
# Elements with a nuclear charge at or above this value get ETB functions.
# NOTE(review): Z = 36 is Kr although the original note said 'Rb' (Z = 37);
# confirm which element was intended.
FIRST_ETB_ELEMENT = getattr(__config__, 'df_addons_aug_start_at', 36)  # 'Rb'

# TODO: Switch to other default scheme for auxiliary basis generation.
# The auxiliary basis set generated by version 2.6 (and earlier) lacks compact
# functions. It may cause higher errors in ERI integrals.
#
# When True, _aug_etb_element() estimates the exponent range of each auxiliary
# shell from the geometric mean of the orbital exponents (the scheme of
# version 2.6 and earlier). When False, the orbital exponents are summed
# instead, which generates more auxiliary functions, especially compact ones.
USE_VERSION_26_AUXBASIS = True

# Obtained from http://www.psicode.org/psi4manual/master/basissets_byfamily.html
#
# Lookup table from orbital basis to recommended auxiliary basis sets.
# Keys are orbital basis names in the form produced by
# _format_basis_name(name).lower(), which is how predefined_auxbasis() builds
# its lookup key. Each value is a pair: index 0 is the JK-fitting basis,
# index 1 is the MP2 (RI) fitting basis. predefined_auxbasis() reads index 1
# when mp2fit is requested and index 0 for hybrid functionals; all other
# cases are delegated to bse_predefined_auxbasis().
DEFAULT_AUXBASIS = {
    # AO basis       JK-fit                     MP2-fit
    'ccpvdz'      : ('cc-pvdz-jkfit'          , 'cc-pvdz-ri'         ),
    'augccpvdz'   : ('aug-cc-pvdz-jkfit'      , 'aug-cc-pvdz-ri'     ),
    'ccpvtz'      : ('cc-pvtz-jkfit'          , 'cc-pvtz-ri'         ),
    'augccpvtz'   : ('aug-cc-pvtz-jkfit'      , 'aug-cc-pvtz-ri'     ),
    'ccpvqz'      : ('cc-pvqz-jkfit'          , 'cc-pvqz-ri'         ),
    'augccpvqz'   : ('aug-cc-pvqz-jkfit'      , 'aug-cc-pvqz-ri'     ),
    'ccpv5z'      : ('cc-pv5z-jkfit'          , 'cc-pv5z-ri'         ),
    'augccpv5z'   : ('aug-cc-pv5z-jkfit'      , 'aug-cc-pv5z-ri'     ),
    'def2svp'     : ('def2-svp-jkfit'         , 'def2-svp-ri'        ),
    'def2svpd'    : ('def2-svp-jkfit'         , 'def2-svpd-ri'       ),
    'def2tzvp'    : ('def2-tzvp-jkfit'        , 'def2-tzvp-ri'       ),
    'def2mtzvp'   : ('def2-tzvp-jkfit'        , 'def2-tzvp-ri'       ),
    'def2tzvpd'   : ('def2-tzvp-jkfit'        , 'def2-tzvpd-ri'      ),
    'def2tzvpp'   : ('def2-tzvpp-jkfit'       , 'def2-tzvpp-ri'      ),
    'def2mtzvpp'  : ('def2-tzvpp-jkfit'       , 'def2-tzvpp-ri'      ),
    'def2tzvppd'  : ('def2-tzvpp-jkfit'       , 'def2-tzvppd-ri'     ),
    'def2qzvp'    : ('def2-qzvp-jkfit'        , 'def2-qzvp-ri'       ),
    #'def2qzvpd'   : ('def2-qzvp-jkfit'        , None                 ),
    'def2qzvpp'   : ('def2-qzvpp-jkfit'       , 'def2-qzvpp-ri'      ),
    'def2qzvppd'  : ('def2-qzvpp-jkfit'       , 'def2-qzvppd-ri'     ),
    'sto3g'       : ('def2-svp-jkfit'         , 'def2-svp-ri'        ),
    '321g'        : ('def2-svp-jkfit'         , 'def2-svp-ri'        ),
    '631g'        : ('cc-pvdz-jkfit'          , 'cc-pvdz-ri'         ),
    '631+g'       : ('heavy-aug-cc-pvdz-jkfit', 'heavyaug-cc-pvdz-ri'),
    '631++g'      : ('aug-cc-pvdz-jkfit'      , 'aug-cc-pvdz-ri'     ),
    '6311g'       : ('cc-pvtz-jkfit'          , 'cc-pvtz-ri'         ),
    '6311+g'      : ('heavy-aug-cc-pvtz-jkfit', 'heavyaug-cc-pvtz-ri'),
    '6311++g'     : ('aug-cc-pvtz-jkfit'      , 'aug-cc-pvtz-ri'     ),
}

class load(ao2mo.load):
    '''Reader of the 3c2e integrals saved in an hdf5 file. It is a thin
    specialization of ao2mo.load whose dataset name defaults to 'j3c'.
    It can be used as a context manager:

    with load(cderifile) as eri:
        print(eri.shape)
    '''
    def __init__(self, eri, dataname='j3c'):
        # Everything is handled by the parent class; only the default
        # dataset name differs.
        super().__init__(eri, dataname)

def _aug_etb_element(nuc_charge, basis, beta):
    '''Derive the even-tempered auxiliary shells of a single element.

    Args:
        nuc_charge : int
            Nuclear charge, used to look up elements.CONFIGURATION.
        basis : list
            Orbital shells. Each shell starts with the angular momentum,
            optionally followed by one integer, then the rows
            [exponent, coeff, coeff, ...] of the primitives.
        beta : float
            Ratio between consecutive exponents of a shell.

    Returns:
        A list of tuples (l, n, emin, beta), one for each auxiliary angular
        momentum that needs at least one function (n > 0).
    '''
    highest_l = max(shell[0] for shell in basis)
    lowest_exp = [1e99] * (highest_l+1)
    highest_exp = [0] * (highest_l+1)
    for shell in basis:
        ang = shell[0]
        # An integer in the second slot is not a primitive (presumably the
        # kappa of the shell); the primitives then start one item later.
        first_prim = 2 if isinstance(shell[1], (int, numpy.integer)) else 1
        prims = numpy.array(shell[first_prim:])
        exps = prims[:,0]
        coeffs = prims[:,1:]
        # Primitives whose largest |coefficient| is not above 1e-3 are ignored
        exps = exps[abs(coeffs).max(axis=1) > 1e-3]
        highest_exp[ang] = max(exps.max(), highest_exp[ang])
        lowest_exp[ang] = min(exps.min(), lowest_exp[ang])

    # 1: H - Be, 2: B - Ca, 3: Sc - La, 4: Ce -
    max_shells = 4 - elements.CONFIGURATION[nuc_charge].count(0)

    # Both schemes cap the auxiliary angular momentum in the same way.
    capped_l = min(highest_l, max_shells)
    l_max_aux = capped_l * 2

    if USE_VERSION_26_AUXBASIS:
        # The scheme of version 2.6 (and earlier): only the orbital shells up
        # to the capped angular momentum are considered and the exponent
        # ranges are estimated by geometric average. It is not recommended
        # because it tends to generate diffuse functions. Important compact
        # functions might be improperly excluded.
        n_orb_l = capped_l + 1
        lo = numpy.array(lowest_exp[:n_orb_l])
        hi = numpy.array(highest_exp[:n_orb_l])
        pair_hi = (hi[:,None] * hi) ** .5 * 2
        pair_lo = (lo[:,None] * lo) ** .5 * 2
    else:
        # All orbital shells are considered and the exponents are summed.
        # More auxiliary functions, especially compact functions, will be
        # generated.
        n_orb_l = highest_l + 1
        lo = numpy.array(lowest_exp)
        hi = numpy.array(highest_exp)
        pair_hi = hi[:,None] + hi
        pair_lo = lo[:,None] + lo

    # l_sum[i,j] = i + j: the auxiliary angular momentum that the product of
    # orbital shells i and j is assigned to.
    l_sum = numpy.arange(n_orb_l)[:,None] + numpy.arange(n_orb_l)
    aux_ls = range(l_max_aux+1)
    aux_hi = numpy.array([pair_hi[l_sum==ll].max() for ll in aux_ls])
    aux_lo = numpy.array([pair_lo[l_sum==ll].min() for ll in aux_ls])

    # Number of functions per auxiliary shell, rounded up
    counts = numpy.log((aux_hi+aux_lo)/aux_lo) / numpy.log(beta)
    counts = numpy.ceil(counts).astype(int)
    return [(l, n, aux_lo[l], beta) for l, n in enumerate(counts) if n > 0]

def aug_etb_for_dfbasis(mol, dfbasis=DFBASIS, beta=ETB_BETA,
                        start_at=FIRST_ETB_ELEMENT):
    '''Augment a predefined fitting basis (weigend by default) with
    even-tempered Gaussian basis (ETB).

    Elements whose nuclear charge is smaller than the charge of ``start_at``
    are assigned ``dfbasis``. For all other elements, ETB shells are derived
    from the orbital basis stored in ``mol._basis``. As reported in the log,
    the exponents of a shell are alpha*beta^n for n = 0..N-1.

    Args:
        mol : Mole object
            Provides the atoms (mol._atom) and the orbital basis (mol._basis).
        dfbasis :
            Basis assigned to the elements below ``start_at``.
        beta : float
            Ratio between consecutive exponents of an ETB shell.
        start_at : element symbol or nuclear charge
            First element that is treated with ETB.

    Returns:
        A dict which maps each atom symbol of mol to ``dfbasis`` or to the
        expanded ETB basis.

    Raises:
        RuntimeError : no ETB shell can be generated for an element.
    '''
    nuc_start = gto.charge(start_at)
    uniq_atoms = {a[0] for a in mol._atom}

    newbasis = {}
    for symb in uniq_atoms:
        nuc_charge = gto.charge(symb)
        if nuc_charge < nuc_start:
            newbasis[symb] = dfbasis
            continue

        etb = _aug_etb_element(nuc_charge, mol._basis[symb], beta)
        if not etb:
            raise RuntimeError(f'Failed to generate even-tempered auxbasis for {symb}')

        newbasis[symb] = gto.expand_etbs(etb)
        # The ratio is unpacked into its own name so that the `beta` argument
        # is not rebound while the remaining atoms are still to be processed.
        for l, n, emin, ratio in etb:
            logger.info(mol, 'ETB for %s: l = %d, exps = %s * %g^n , n = 0..%d',
                        symb, l, emin, ratio, n-1)

    return newbasis

def aug_etb(mol, beta=ETB_BETA):
    '''Generate the even-tempered auxiliary Gaussian basis for the atoms of
    mol.

    It delegates to aug_etb_for_dfbasis with start_at=0, a threshold that no
    nuclear charge is expected to fall below, so that the predefined fitting
    basis is not assigned to any element.
    '''
    etb_basis = aug_etb_for_dfbasis(mol, start_at=0, beta=beta)
    return etb_basis

def make_auxbasis(mol, *, xc='HF', mp2fit=False):
    '''Select an auxiliary basis for each atom of mol.

    For orbital basis sets given by name, the predefined auxiliary basis
    (see predefined_auxbasis) is used when it exists and provides the
    element. If this does not cover every entry of the orbital basis,
    even-tempered Gaussians are generated with aug_etb and the predefined
    choices are laid on top of them.

    Kwargs:
        xc : str
            XC functional, forwarded to predefined_auxbasis.
        mp2fit : bool
            Request MP2 fitting basis, forwarded to predefined_auxbasis.

    Returns:
        A dict of auxiliary basis keyed by atom symbol.
    '''
    def _provides(auxb, atom_label):
        # Test if basis auxb for the element of atom_label is available
        try:
            gto.basis.load(auxb, elements._std_symbol_without_ghost(atom_label))
        except BasisNotFoundError:
            return False
        return True

    uniq_atoms = {a[0] for a in mol._atom}
    auxbasis = {}

    # Step 1: one orbital basis entry per atom label
    if isinstance(mol.basis, str):
        _basis = dict.fromkeys(uniq_atoms, mol.basis)
        default_auxbasis = predefined_auxbasis(mol, mol.basis, xc, mp2fit)
        if default_auxbasis:
            auxbasis = dict.fromkeys(uniq_atoms, default_auxbasis)
            logger.debug(mol, 'Default auxbasis %s is applied universally', default_auxbasis)
    elif isinstance(mol.basis, dict):
        if 'default' in mol.basis:
            # The 'default' entry applies to every atom without its own entry
            _basis = dict.fromkeys(uniq_atoms, mol.basis['default'])
            _basis.update(mol.basis)
            _basis.pop('default')
        else:
            _basis = mol.basis
    else:
        _basis = mol._basis or {}

    # Step 2: pick the predefined auxiliary basis where possible
    for k, obs in _basis.items():
        if not isinstance(obs, str):
            continue

        if k in auxbasis:
            # The universal choice is dropped for elements it does not cover
            if not _provides(auxbasis[k], k):
                del auxbasis[k]
            continue

        balias = _format_basis_name(obs)
        if gto.basis._is_pople_basis(balias):
            # Only the part up to the first 'g' is used for the lookup
            balias = balias.split('g')[0] + 'g'
        auxb = predefined_auxbasis(mol, balias, xc, mp2fit)
        if auxb is None or not _provides(auxb, k):
            continue
        auxbasis[k] = auxb
        logger.info(mol, 'Default auxbasis %s is used for %s %s',
                    auxb, k, obs)

    if len(auxbasis) == len(_basis):
        return auxbasis

    # Step 3: Some AO basis not found in DEFAULT_AUXBASIS. Generate ETB for
    # all atoms, then let the predefined choices take precedence.
    auxdefault = auxbasis
    auxbasis = aug_etb(mol)
    auxbasis.update(auxdefault)
    aux_etb = set(auxbasis) - set(auxdefault)
    if aux_etb:
        logger.warn(mol, 'Even tempered Gaussians are generated as '
                    'DF auxbasis for  %s', ' '.join(aux_etb))
        for k in aux_etb:
            logger.debug(mol, '  ETB auxbasis for %s  %s', k, auxbasis[k])
    return auxbasis

# TODO: add auxbasis keyword etb and auto
def make_auxmol(mol, auxbasis=None):
    '''Generate a fake Mole object which uses the density fitting auxbasis as
    the basis sets.  If auxbasis is not specified, the optimized auxiliary fitting
    basis set will be generated according to the rules recorded in
    pyscf.df.addons.DEFAULT_AUXBASIS.  If the optimized auxiliary basis is not
    available (either not specified in DEFAULT_AUXBASIS or the basis set of the
    required elements not defined in the optimized auxiliary basis),
    even-tempered Gaussian basis set will be generated.

    See also the paper JCTC, 13, 554 about generating auxiliary fitting basis.

    Kwargs:
        auxbasis : str, list, tuple or dict
            Similar to the input of orbital basis in Mole object. A str, list
            or tuple is applied to every atom. A dict is keyed by atom symbol
            and may hold a 'default' entry for the atoms not listed.

    Returns:
        A shallow copy of mol whose basis, _basis, _atm, _bas and _env
        describe the auxiliary basis.

    Raises:
        BasisNotFoundError : the auxiliary basis is not available for some
            elements. When auxbasis is a basis name, a hint about generating
            even-tempered functions is written to stderr before the error is
            re-raised.
    '''
    pmol = mol.copy(deep=False)

    if auxbasis is None:
        auxbasis = make_auxbasis(mol)
    pmol.basis = auxbasis

    # Expand the input to one entry per atom symbol
    if isinstance(auxbasis, (str, list, tuple)):
        uniq_atoms = {a[0] for a in mol._atom}
        _basis = {a: auxbasis for a in uniq_atoms}
    else:
        assert isinstance(auxbasis, dict)
        if 'default' in auxbasis:
            uniq_atoms = {a[0] for a in mol._atom}
            _basis = {a: auxbasis['default'] for a in uniq_atoms}
            _basis.update(auxbasis)
            del (_basis['default'])
        else:
            _basis = auxbasis

    try:
        pmol._basis = pmol.format_basis(_basis)
    except BasisNotFoundError:
        if isinstance(auxbasis, str):
            # The hint is a diagnostic, so it goes to stderr and does not mix
            # with the regular output of the program.
            print(f'''
Some elements are not found in the specified auxiliary basis set {auxbasis}.
To proceed, you can generate even-tempered Gaussian (ETB) functions for the missing elements.
Three routines are available for this system:
1. PySCF's built-in basis generation function
    mf.with_df.auxbasis = pyscf.df.make_auxbasis(mol)
2. ORCA recommended AutoAux ETB
    mf.with_df.auxbasis = pyscf.df.autoaux(mol)
3. Gaussian package recommended ETB
    mf.with_df.auxbasis = pyscf.df.autoabs(mol)
''', file=sys.stderr)
        raise

    # Note: To pass parameters like gauge origin, rsh-omega to auxmol,
    # mol._env[:PTR_ENV_START] must be copied to auxmol._env
    pmol._atm, pmol._bas, pmol._env = \
            pmol.make_env(mol._atom, pmol._basis, mol._env[:gto.PTR_ENV_START])
    pmol._built = True
    logger.debug(mol, 'num shells = %d, num cGTOs = %d',
                 pmol.nbas, pmol.nao_nr())
    return pmol

def bse_predefined_auxbasis(mol, basis, xc='HF', mp2fit=False):
    '''Look up the auxiliary basis recorded in the BSE metadata for the given
    orbital basis.

    The kind of fitting basis depends on the request: 'rifit' for mp2fit,
    'jkfit' for hybrid XC functionals, otherwise the first record found among
    'jfit', 'dftjfit' and 'jkfit'. None is returned if basis is not a
    string or if no matching basis set is found.
    '''
    if not isinstance(basis, str):
        return None

    try:
        from pyscf.dft.libxc import is_hybrid_xc
    except ImportError:
        from pyscf.dft.xcfun import is_hybrid_xc

    alias = _format_basis_name(basis).lower()
    basis_meta = gto.mole.BSE_META.get(alias)
    if not basis_meta:
        return None

    # The third field of the metadata holds the auxiliary basis records
    auxiliaries = basis_meta[2]

    if mp2fit:
        auxbasis = auxiliaries.get('rifit')
        if auxbasis:
            logger.debug(mol, f'BSE predefined RIFIT basis set {auxbasis} for mp2fit')
        return auxbasis

    if is_hybrid_xc(xc):
        auxbasis = auxiliaries.get('jkfit')
        if auxbasis:
            logger.debug(mol, f'BSE predefined JKFIT basis set {auxbasis} for {xc}')
        return auxbasis

    # Functionals without exact exchange: take the first available record
    auxbasis = None
    for record in ('jfit', 'dftjfit', 'jkfit'):
        auxbasis = auxiliaries.get(record)
        if auxbasis is not None:
            break
    if auxbasis:
        logger.debug(mol, f'BSE predefined JFIT basis set {auxbasis} for {xc}')
    return auxbasis

def predefined_auxbasis(mol, basis, xc='HF', mp2fit=False):
    '''Find the predefined auxiliary basis set for the specified orbital
    basis set and XC functional.

    The Psi4 recommendation in DEFAULT_AUXBASIS is consulted first. It
    answers the requests for MP2 fitting and for hybrid XC functionals. Any
    other request, and any orbital basis missing in that table, is passed to
    the BSE database. None is returned if basis is not a string or if no
    matching basis set is found.
    '''
    if not isinstance(basis, str):
        return None

    try:
        from pyscf.dft.libxc import is_hybrid_xc
    except ImportError:
        from pyscf.dft.xcfun import is_hybrid_xc

    alias = _format_basis_name(basis).lower()
    in_psi4_table = alias in DEFAULT_AUXBASIS

    if in_psi4_table and mp2fit:
        auxbasis = DEFAULT_AUXBASIS[alias][1]
        logger.debug(mol, f'Psi4 predefined RIFIT basis set {auxbasis} for mp2fit')
        return auxbasis

    # The functional is inspected only for basis sets listed in the table
    if in_psi4_table and is_hybrid_xc(xc):
        auxbasis = DEFAULT_AUXBASIS[alias][0]
        logger.debug(mol, f'Psi4 predefined JKFIT basis set {auxbasis} for {xc}')
        return auxbasis

    return bse_predefined_auxbasis(mol, basis, xc, mp2fit)

# These three names are only needed as default argument values of
# aug_etb_for_dfbasis and aug_etb, which were bound when the functions were
# defined. Deleting the names removes them from the module namespace.
del (DFBASIS, ETB_BETA, FIRST_ETB_ELEMENT)
